/*
* drvAESDMA.c- Sigmastar
*
* Copyright (C) 2018 Sigmastar Technology Corp.
*
* Author: nick.lin <nick.lin@sigmastar.com.tw>
*
* This software is licensed under the terms of the GNU General Public
* License version 2, as published by the Free Software Foundation, and
* may be copied, distributed, and modified under those terms.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
*/


#ifndef _DRV_AESDMA_H_
#include "drvAESDMA.h"
#endif
#include <common.h>
// added COCOA-1683: Verification of RTOS/kernel images [1/2]
#include "asm/arch/mach/io.h"
#include "asm/arch/mach/platform.h"

/* Length in bytes of the EMSA-PSS data block (maskedDB / DB) handled by
 * pss_verify(): a 256-byte encoded message minus the 32-byte SHA-256 hash
 * and the 1-byte 0xbc trailer, i.e. 256 - 32 - 1. */
#define DBLEN   223

extern void invalidate_dcache_range(unsigned long start, unsigned long stop);

/**
    Run one AES DMA transfer synchronously (the engine is polled until done).

    @param[in] pConfig Transfer description: source/destination addresses,
               size, key source (cipher key / eFuse key / HW key), direction,
               optional IV and chain mode (ECB / CTR / CBC).

    An unknown eKeyType or eChainMode makes the function return early without
    starting the transfer; no error is reported to the caller.
*/
void MDrv_AESDMA_Run(aesdmaConfig *pConfig)
{
    HAL_AESDMA_Reset();

    // Program input address, transfer length and output window.
    HAL_AESDMA_SetFileinAddr(pConfig->u32SrcAddr);
    HAL_AESDMA_SetXIULength(pConfig->u32Size);
    HAL_AESDMA_SetFileoutAddr(pConfig->u32DstAddr, pConfig->u32Size);

    // Select the key source; only the cipher-key path loads pu16Key.
    switch(pConfig->eKeyType)
    {
    case E_AESDMA_KEY_CIPHER:
        HAL_AESDMA_UseCipherKey();
        HAL_AESDMA_SetCipherKey(pConfig->pu16Key);
        break;
    case E_AESDMA_KEY_EFUSE:
        HAL_AESDMA_UseEfuseKey();
        break;
    case E_AESDMA_KEY_HW:
        HAL_AESDMA_UseHwKey();
        break;
    default:
        return;
    }

    if(pConfig->bDecrypt)
    {
        HAL_AESDMA_CipherDecrypt();
    }
    else
    {
        HAL_AESDMA_CipherEncrypt();
    }

    if(pConfig->bSetIV)
    {
        HAL_AESDMA_SetIV(pConfig->pu16IV);
    }

    HAL_AESDMA_Enable();

    // Chain mode is selected after enabling; ECB and CTR also adjust earlier settings.
    switch(pConfig->eChainMode)
    {
    case E_AESDMA_CHAINMODE_ECB:
        HAL_AESDMA_SetChainModeECB();
        HAL_AESDMA_SetXIULength(((pConfig->u32Size+15)/16)*16); // ECB mode size should align 16byte
        break;
    case E_AESDMA_CHAINMODE_CTR:
        HAL_AESDMA_SetChainModeCTR();
        HAL_AESDMA_CipherEncrypt();  // CTR mode can't set cipher_decrypt bit
        break;
    case E_AESDMA_CHAINMODE_CBC:
        HAL_AESDMA_SetChainModeCBC();
        break;
    default:
        // NOTE(review): returns after HAL_AESDMA_Enable() without HAL_AESDMA_Reset();
        // the next call resets the engine at its start.
        return;
    }

    HAL_AESDMA_FileOutEnable(1);

    HAL_AESDMA_Start(1);

    // Wait for ready.
    // Busy-wait with no timeout until the DMA-done status bit is set.
    while((HAL_AESDMA_GetStatus() & AESDMA_CTRL_DMA_DONE) != AESDMA_CTRL_DMA_DONE);
    HAL_AESDMA_Reset();

}

/**
    Hash a memory region with the SHA engine, polling until it is ready.

    @param[in]  u32SrcAddr Address of the data handed to the engine.
    @param[in]  u32Size    Number of bytes to hash.
    @param[in]  eMode      E_SHA_MODE_1 (SelMode 0) or E_SHA_MODE_256
                           (SelMode 1); any other value returns without
                           starting the engine.
    @param[out] pu16Output Destination passed to HAL_SHA_Out().

    NOTE(review): the engine presumably reads u32SrcAddr from memory, not the
    CPU cache. The callers pss_verify() and mgf() flush the data cache before
    calling. The (U32) cast of pu16Output assumes 32-bit addresses.
*/
void MDrv_SHA_Run(U32 u32SrcAddr, U32 u32Size, enumShaMode eMode, U16 *pu16Output)
{
    HAL_SHA_Reset();

    HAL_SHA_SetAddress(u32SrcAddr);
    HAL_SHA_SetLength(u32Size);

    switch(eMode)
    {
    case E_SHA_MODE_1:
        HAL_SHA_SelMode(0);
        break;
    case E_SHA_MODE_256:
        HAL_SHA_SelMode(1);
        break;
    default:
        return;
    }

    HAL_SHA_Start();

    // Wait for the SHA done.
    // Busy-wait with no timeout until SHARNG_CTRL_SHA_READY is set.
    while((HAL_SHA_GetStatus() & SHARNG_CTRL_SHA_READY) != SHARNG_CTRL_SHA_READY);

    HAL_SHA_Out((U32)pu16Output);

    HAL_SHA_Clear();
    HAL_SHA_Reset();
}

/**
    Run one RSA modular exponentiation on the signature and read the result.

    @param[in,out] pConfig Signature pointer, key selection (HW key, or
                   software modulus N / exponent E), key length, and the
                   output buffer that receives the decoded words.

    The result is 64 words for a HW key or u32KeyLen == 2048, otherwise 32
    words. With PKCS verification the words are stored in read order. With
    PSS verification they are stored in reverse order, starting at index 63.

    Under CONFIG_SRAM_DUMMY_ACCESS_RSA the function is compiled with GCC
    optimize("O0"). The reason is not visible in this file.
*/
#ifdef CONFIG_SRAM_DUMMY_ACCESS_RSA
__attribute__((optimize("O0"))) void MDrv_RSA_Run(rsaConfig *pConfig)
#else
void MDrv_RSA_Run(rsaConfig *pConfig)
#endif

{
    int nOutSize;
    int i;

    HAL_RSA_Reset();

    // NOTE(review): runRSA() passes u32KeyLen = 2048 (bits) and the HW-key path
    // leaves it 0. Both yield 0x3F here, which equals 64 words - 1. Confirm the
    // unit the register expects before supporting other key sizes (1023 & 0x3F
    // is also 0x3F).
    HAL_RSA_SetKeyLength((pConfig->u32KeyLen-1) & 0x3F);
    HAL_RSA_SetKeyType(pConfig->bHwKey, pConfig->bPublicKey);

    HAL_RSA_Ind32Ctrl(0);
    HAL_RSA_Ind32Ctrl(1);
    HAL_RSA_LoadSignInverse_2byte((U16*)pConfig->pu32Sig);

    // Software key: load N (if supplied) and E; a HW key needs neither.
    if (!pConfig->bHwKey)
    {
        if (pConfig->pu32KeyN)
        {
            //HAL_RSA_LoadKeyN(pConfig->pu32KeyN);
            HAL_RSA_LoadKeyNInverse(pConfig->pu32KeyN);
        }

        HAL_RSA_LoadKeyE(pConfig->pu32KeyE);

    }

#if 0
    if((!pConfig->bHwKey) && (pConfig->pu32KeyN))
    {
        HAL_RSA_LoadKeyNInverse(pConfig->pu32KeyN);
    }
    if((!pConfig->bHwKey) && (pConfig->pu32KeyE))
    {
        HAL_RSA_LoadKeyEInverse(&pConfig->pu32KeyE);
    }
#endif

    HAL_RSA_ExponetialStart();

    // Busy-wait with no timeout until the exponentiation completes.
    while((HAL_RSA_GetStatus() & RSA_STATUS_RSA_DONE) != RSA_STATUS_RSA_DONE);

    // Number of 32-bit result words: 2048-bit for HW key or 2048-bit SW key, else 1024-bit.
    if((pConfig->bHwKey) || (pConfig->u32KeyLen == 2048))
    {
        nOutSize = (2048/8)/4;
    }
    else
    {
        nOutSize = (1024/8)/4;
    }

    HAL_RSA_Ind32Ctrl(0);
    for(i = 0; i < nOutSize; i++)
    {
        HAL_RSA_SetFileOutAddr(i);
        HAL_RSA_FileOutStart();
        if (COCOA_BLUE_SIGNATURE_VERIFICATION_PKCS)
        {
            *(pConfig->pu32Output+i) = HAL_RSA_FileOut();
        }
        else if (COCOA_BLUE_SIGNATURE_VERIFICATION_PSS)
        {
            // Reverse word order so pss_verify() sees the 0xbc trailer at byte 255.
            // NOTE(review): fixed index 63 assumes nOutSize == 64; a 32-word result
            // would land in indices 63..32. Confirm only 2048-bit keys reach here.
            *(pConfig->pu32Output + 63 - i) = HAL_RSA_FileOut();
        }
    }
    HAL_RSA_FileOutEnd();
    HAL_RSA_Reset();
}

/**
    Decrypt an image in place with AES-ECB.

    @param[in] u32ImageAddr Image address (used as both source and destination).
    @param[in] u32ImageSize Image size in bytes.
    @param[in] pu16Key      Cipher key. If it is NULL or its first halfword is 0,
                            the hardware key is used instead.

    NOTE(review): the invalidated range ends at the image end rounded DOWN to a
    64-byte line, so a trailing partial line is not invalidated. Nothing is
    invalidated after the in-place DMA either. Confirm the caller covers both.
*/
void runDecrypt(U32 u32ImageAddr, U32 u32ImageSize, U16* pu16Key)
{
    const U32 u32InvalidateEnd = (u32ImageAddr + u32ImageSize) & ~(0x00000040 - 1);
    aesdmaConfig config = {
        .u32SrcAddr = u32ImageAddr,
        .u32DstAddr = u32ImageAddr,
        .u32Size    = u32ImageSize,
        .bDecrypt   = 1,
        .eChainMode = E_AESDMA_CHAINMODE_ECB,   // default use ECB mode
    };

    invalidate_dcache_range(u32ImageAddr, u32InvalidateEnd);

    if (pu16Key != 0 && *pu16Key != 0)
    {
        config.eKeyType = E_AESDMA_KEY_CIPHER;
        config.pu16Key  = pu16Key;
    }
    else
    {
        // No usable software key: fall back to the hardware key (not eFuse).
        config.eKeyType = E_AESDMA_KEY_HW;
    }

    MDrv_AESDMA_Run(&config);
}

/**
    Compute the SHA-256 digest of a memory region with the hardware engine.

    @param[in]  u32SrcAddr Address of the data to hash.
    @param[in]  len        Number of bytes to hash.
    @param[out] sha_out    Receives the 32-byte digest.

    @return Always TRUE.
*/
BOOL runSHA256(U32 u32SrcAddr, U32 len, U32 *sha_out)
{
    U16 *pu16Digest = (U16*)sha_out;

    // Enable the AESDMA clock gate before driving the SHA engine.
    OUTREG8(GET_REG_ADDR(REG_ADDR_BASE_CLKGEN, CKG_AESDMA), CKG_AESDMA_ENABLE);

    MDrv_SHA_Run(u32SrcAddr, len, E_SHA_MODE_256, pu16Digest);

    return TRUE;
}

/**
    Verify an RSA signature against a precomputed SHA-256 digest.

    @param[in] u32SigAddr Address of the signature fed to the RSA engine.
    @param[in] sha_sum    SHA-256 digest of the signed data (32 bytes).
    @param[in] pu32Key    Public modulus N. If it is NULL or its first word is 0,
                          the hardware key is used instead.

    The decoded block is checked with pkcs_verify() or pss_verify(), depending on
    COCOA_BLUE_SIGNATURE_VERIFICATION_PKCS / COCOA_BLUE_SIGNATURE_VERIFICATION_PSS.

    @return TRUE if the signature verifies, FALSE otherwise.
*/
BOOL runRSA(U32 u32SigAddr, U32 *sha_sum, U32 *pu32Key)
{
    // Zeroed so the verifiers never read indeterminate words if the engine
    // returns fewer than 64 words.
    U32 rsa_out[64] = {0};
    rsaConfig config={0};
    U32 eValue = DEFAULT_EXP_VALUE;
    BOOL verified = FALSE;

    OUTREG8(GET_REG_ADDR(REG_ADDR_BASE_CLKGEN, CKG_AESDMA), CKG_AESDMA_ENABLE);

    if(pu32Key == 0 || *pu32Key == 0)
    {
        config.bHwKey=1;
    }
    else
    {
        config.pu32KeyN = pu32Key;
        // NOTE(review): stores the exponent value itself, so the rsaConfig field
        // (declared elsewhere) must hold a value rather than a pointer. Confirm.
        config.pu32KeyE = eValue;
        config.u32KeyLen = 2048;
    }
    config.bPublicKey = 1;
    config.pu32Sig = (U32*)(u32SigAddr);
    config.pu32Output = rsa_out;

    MDrv_RSA_Run(&config);

    // The verifiers take byte pointers; cast explicitly instead of relying on
    // an implicit (constraint-violating) U32* -> U8* conversion.
    if (COCOA_BLUE_SIGNATURE_VERIFICATION_PKCS)
    {
        verified = pkcs_verify((U8*)sha_sum, (U8*)rsa_out);
    }
    else if (COCOA_BLUE_SIGNATURE_VERIFICATION_PSS)
    {
        verified = pss_verify((U8*)sha_sum, (U8*)rsa_out);
    }
    if(!verified)
    {
        printf("[U-Boot] RSA check failed.\n");
        return FALSE;
    }

    printf("%s: compare RSA address --> MDrv_RSA_Run done\n", __func__);
    return verified;
}

/**
    Authenticate an image whose RSA signature is stored directly after it.

    @param[in] u32ImageAddr Address of the image.
    @param[in] u32ImageSize Image size in bytes. The signature is read from
                            u32ImageAddr + u32ImageSize.
    @param[in] pu32Key      Public modulus, or NULL / first word 0 for the HW key.

    @return Result of runRSA(): TRUE if the signature matches the image digest.
*/
BOOL runAuthenticate(U32 u32ImageAddr, U32 u32ImageSize, U32* pu32Key)
{
    const U32 u32SigAddr = u32ImageAddr + u32ImageSize;
    U32 au32Digest[8];

    OUTREG8(GET_REG_ADDR(REG_ADDR_BASE_CLKGEN, CKG_AESDMA), CKG_AESDMA_ENABLE);

    // Only the image body is hashed; the trailing signature is not included.
    MDrv_SHA_Run(u32ImageAddr, u32ImageSize, E_SHA_MODE_256, (U16*)au32Digest);

    return runRSA(u32SigAddr, au32Digest, pu32Key);
}

/**
    Compare two byte arrays, hardened against fault-injection (glitch) attacks.

    Each byte comparison is preceded by a short random delay. After the loop, a
    match counter is re-checked so that a glitch which skips loop iterations does
    not yield TRUE.

    @param[in] data0 The first data array to be compared
    @param[in] data1 The second data array to be compared
    @param[in] u8Size The size (byte) of comparison

    @return TRUE for the same, FALSE for not the same
*/
BOOL data_compare(U8* data0, U8* data1, U8 u8Size)
{
    U8 i;
    // volatile: without it the counter always equals i, so the compiler may
    // prove the final check redundant and delete it, removing the protection.
    volatile U8 u8cmpOKcnt = 0;

    for (i = 0; i < u8Size; i++)
    {
        // add random delays to upset glitching attacks. ~1ns per instruction @ 1GHz so even 1us is massive
        srand(get_ticks() ^ rand());
        int del = rand() & 0xF;
        udelay(del);
        if (data0[i] == data1[i])
            u8cmpOKcnt++;
        else
            return FALSE;
    }

    if (u8cmpOKcnt == u8Size) // make sure we didn't skip any comparisons
        return TRUE;

    return FALSE;
}

/**
    PKCS#1 v1.5-style check: compare the image digest with the RSA-decoded block.

    @param[in] pu8Dig The SHA256 digest of the being verified image, 32 byte.
    @param[in] pu8Sig The RSA decoded signature, 256 byte. It's expected to
               be in the following format:
               | 0x00 | 0x01 | PS (221byte) | 0x00 | Digest |

    NOTE(review): only bytes 0..31 of pu8Sig are compared with the digest. The
    0x00 0x01 PS 0x00 padding above is not checked. Reading the digest from
    offset 0 presumably relies on the word order produced by MDrv_RSA_Run's PKCS
    path; confirm against the layout above.

    @return TRUE for valid, FALSE for invalid
*/
BOOL pkcs_verify(U8* pu8Dig, U8* pu8Sig)
{
    const U8 u8DigestLen = 32;   // SHA-256 digest size in bytes

    return data_compare(pu8Dig, pu8Sig, u8DigestLen);
}

/**
    Verify an EMSA-PSS encoded message (RFC 8017, 9.1.2 EMSA-PSS-VERIFY).

    Assumes a 2048-bit modulus, SHA-256 and MGF1-SHA-256. The salt length is
    recovered from the position of the 0x01 separator in DB.

    @param[in] pu8Dig The SHA256 digest of the being verified image, 32 byte.
    @param[in] pu8Sig The RSA decoded signature, 256 byte. It's expected to
               be in the following format:
               | MaskedDb (223 byte) | H (32 byte, MGF seed) | 0xbc |

    @return TRUE for valid, FALSE for invalid
*/
BOOL pss_verify(U8* pu8Dig, U8* pu8Sig)
{
    /* M' = 8 zero bytes || mHash (32 byte) || salt (up to DBLEN - 1 byte),
     * i.e. at most 262 bytes. The buffer spans whole 64-byte cache lines so the
     * flush below covers every byte the SHA engine may read. */
    U8 au8ShaSrc[320] __attribute__((aligned(64)));
    U8 au8Db[DBLEN], au8Hash[32];
    U8 *pu8MaskedDb, *pu8H;
    U32 i, u32SaltLen, u32MsgLen;

    /* Trailer field must be 0xbc. */
    if (pu8Sig[255] != 0xbc)
        return FALSE;

    pu8MaskedDb = pu8Sig;
    pu8H = pu8Sig + DBLEN;

    /* emBits = 2047 for a 2048-bit modulus: the leftmost bit of maskedDB
     * must be zero (step 6). */
    if (pu8MaskedDb[0] & 0x80)
        return FALSE;

    /* DB = maskedDB XOR MGF1(H), then clear the leftmost bit (steps 7-9). */
    mgf(pu8H, au8Db, DBLEN);
    for (i = 0; i < DBLEN; i++)
        au8Db[i] ^= pu8MaskedDb[i];
    au8Db[0] &= 0x7F;

    /* DB = PS (all zero) || 0x01 || salt. Every byte before the separator must
     * be zero, and the separator must be present (step 10). */
    for (i = 0; i < DBLEN && au8Db[i] == 0x00; i++)
        ;
    if (i == DBLEN || au8Db[i] != 0x01)
        return FALSE;
    u32SaltLen = DBLEN - i - 1;   /* 0..222, so 40 + salt does not fit in a U8 */

    /* M' = 0x00 * 8 || mHash || salt */
    memset(au8ShaSrc, 0x00, 8);
    memcpy(au8ShaSrc + 8, pu8Dig, 32);
    if (u32SaltLen)
        memcpy(au8ShaSrc + 40, &au8Db[DBLEN - u32SaltLen], u32SaltLen);
    u32MsgLen = 40 + u32SaltLen;

    /* The SHA engine presumably reads M' from memory: write it back first. */
    flush_dcache_range((U32)au8ShaSrc, (U32)au8ShaSrc + sizeof(au8ShaSrc));
    chip_flush_miu_pipe();
    MDrv_SHA_Run((U32)au8ShaSrc, u32MsgLen, E_SHA_MODE_256, (U16*)au8Hash);

    /* Valid iff H == Hash(M'). */
    return data_compare(pu8H, au8Hash, 32);
}

/** mgf - Mask Generation Function (MGF1 with the SHA-256 engine).

    Each round hashes seed || 4-byte big-endian counter; only the low counter
    byte is written, so at most 256 rounds are distinguishable. Round outputs
    are concatenated until u8MaskLen bytes have been produced.

    @param[in]  pu8Seed Used to generate the mask (32 bytes are read).
    @param[out] pu8MaskDb Mask, u8MaskLen bytes are written.
    @param[in]  u8MaskLen the size of the mask.

    @return Nothing
*/
void mgf(U8* pu8Seed , U8* pu8MaskDb, U8 u8MaskLen)
{
    U8 au8ShaSrc[36] __attribute__((aligned(16)));   // seed (32) || counter (4)
    U8 au8ShaOut[32];
    U8 u8Rounds = (u8MaskLen + 31) >> 5;             // ceil(u8MaskLen / 32)
    U8 u8Left = u8MaskLen;
    U8 u8Counter;

    memcpy(au8ShaSrc, pu8Seed, 32);
    memset(&au8ShaSrc[32], 0x00, 3);

    for (u8Counter = 0; u8Counter < u8Rounds; u8Counter++)
    {
        U8 u8Chunk = (u8Left > 32) ? 32 : u8Left;

        au8ShaSrc[35] = u8Counter;

        // Write the hash input back from the cache before the engine reads it
        // (the 64-byte range extends past the 36-byte array).
        flush_dcache_range((U32)au8ShaSrc, (U32)au8ShaSrc + 64);
        chip_flush_miu_pipe();
        MDrv_SHA_Run((U32)au8ShaSrc, 36, E_SHA_MODE_256, (U16*)au8ShaOut);

        memcpy(pu8MaskDb, au8ShaOut, u8Chunk);
        pu8MaskDb += u8Chunk;
        u8Left -= u8Chunk;
    }
}
